Building Your First Linear/Quadratic Discriminant Model

Imagine we are detectives in a murder mystery. A local wine producer, Ronald Fisher, was poisoned at a dinner party when somebody replaced the wine in the carafe with wine poisoned with arsenic. 3 other rival wine producers were at the party & are our prime suspects. If we can trace the wine to one of the 3 vineyards, we’ll find our murderer. As luck would have it, we have access to some previous chemical analysis of the wines from each of the vineyards, & we order an analysis of the poisoned carafe at the scene of the crime. Our task is to build a model that will tell us which vineyard the wine with arsenic came from &, therefore, the guilty party.

Loading & Exploring the Data Set

We’ll start by exploring our data set. We have a tibble containing 178 cases & 14 variables of measurements made on various wine bottles.

data(wine, package = 'HDclassif')
wineTib <- as_tibble(wine)
wineTib
## # A tibble: 178 × 14
##    class    V1    V2    V3    V4    V5    V6    V7    V8    V9   V10   V11   V12
##    <int> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
##  1     1  14.2  1.71  2.43  15.6   127  2.8   3.06  0.28  2.29  5.64  1.04  3.92
##  2     1  13.2  1.78  2.14  11.2   100  2.65  2.76  0.26  1.28  4.38  1.05  3.4 
##  3     1  13.2  2.36  2.67  18.6   101  2.8   3.24  0.3   2.81  5.68  1.03  3.17
##  4     1  14.4  1.95  2.5   16.8   113  3.85  3.49  0.24  2.18  7.8   0.86  3.45
##  5     1  13.2  2.59  2.87  21     118  2.8   2.69  0.39  1.82  4.32  1.04  2.93
##  6     1  14.2  1.76  2.45  15.2   112  3.27  3.39  0.34  1.97  6.75  1.05  2.85
##  7     1  14.4  1.87  2.45  14.6    96  2.5   2.52  0.3   1.98  5.25  1.02  3.58
##  8     1  14.1  2.15  2.61  17.6   121  2.6   2.51  0.31  1.25  5.05  1.06  3.58
##  9     1  14.8  1.64  2.17  14      97  2.8   2.98  0.29  1.98  5.2   1.08  2.85
## 10     1  13.9  1.35  2.27  16      98  2.98  3.15  0.22  1.85  7.22  1.01  3.55
## # … with 168 more rows, and 1 more variable: V13 <int>

As you can see, the data set isn’t well curated: the column names (V1, V2, etc.) tell us nothing about what each variable measures. We could continue working with these names, but they would be hard to keep track of, so instead we’ll manually set informative variable names. We’ll also convert the class variable to a factor.

colnames(wineTib) <- c('Class', 'Alco', 'Malic', 'Ash', 'Alk', 'Mag', 'Phe', 'Flav', 'Non_flav', 'Proan', 'Col', 'Hue', 'OD', 'Prol')
wineTib$Class <- as.factor(wineTib$Class)
wineTib
## # A tibble: 178 × 14
##    Class  Alco Malic   Ash   Alk   Mag   Phe  Flav Non_flav Proan   Col   Hue
##    <fct> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl>    <dbl> <dbl> <dbl> <dbl>
##  1 1      14.2  1.71  2.43  15.6   127  2.8   3.06     0.28  2.29  5.64  1.04
##  2 1      13.2  1.78  2.14  11.2   100  2.65  2.76     0.26  1.28  4.38  1.05
##  3 1      13.2  2.36  2.67  18.6   101  2.8   3.24     0.3   2.81  5.68  1.03
##  4 1      14.4  1.95  2.5   16.8   113  3.85  3.49     0.24  2.18  7.8   0.86
##  5 1      13.2  2.59  2.87  21     118  2.8   2.69     0.39  1.82  4.32  1.04
##  6 1      14.2  1.76  2.45  15.2   112  3.27  3.39     0.34  1.97  6.75  1.05
##  7 1      14.4  1.87  2.45  14.6    96  2.5   2.52     0.3   1.98  5.25  1.02
##  8 1      14.1  2.15  2.61  17.6   121  2.6   2.51     0.31  1.25  5.05  1.06
##  9 1      14.8  1.64  2.17  14      97  2.8   2.98     0.29  1.98  5.2   1.08
## 10 1      13.9  1.35  2.27  16      98  2.98  3.15     0.22  1.85  7.22  1.01
## # … with 168 more rows, and 2 more variables: OD <dbl>, Prol <int>

We’ll plot the data to get an idea of how the compounds vary between the vineyards.

wineUntidy <- gather(wineTib, 'Variable', 'Value', -Class)

ggplotly(
  ggplot(wineUntidy, aes(Class, Value)) +
    facet_wrap(~Variable, scales = 'free_y') +
    geom_boxplot() +
    theme_bw()
)

Any data scientist (or detective) working on this case would jump for joy at how many obvious differences there are between the wines from the 3 different vineyards. Because the classes are so separable, we should easily be able to build a well-performing classification model.
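To back up what the boxplots suggest with some numbers, we could summarise a few of the compounds by vineyard. This is a small sketch, not part of the original analysis; it assumes wineTib from above and the dplyr package are loaded, and the three variables chosen here are just illustrative.

```r
library(dplyr)

# Per-vineyard means for three compounds, to confirm the class
# differences visible in the boxplots
wineTib %>%
  group_by(Class) %>%
  summarise(
    meanAlco = mean(Alco),   # alcohol
    meanFlav = mean(Flav),   # flavanoids
    meanProl = mean(Prol)    # proline
  )
```

If the class means differ by much more than the within-class spread, as they do here, discriminant analysis has a lot to work with.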

Training the Models

Let’s define our task & learner, & build the model as usual. This time, we supply "classif.lda" as the argument to makeLearner() to specify that we’re going to use LDA.

wineTask <- makeClassifTask(data = wineTib, target = 'Class')
lda <- makeLearner('classif.lda')
ldaModel <- train(lda, wineTask)

Let’s extract the model information using the getLearnerModel() function & get the DF values for each case using the predict() function. By printing head(ldaPreds), we can see that the model has learned two DFs, LD1 & LD2, & that the predict() function has indeed returned the values of these functions for each case in our wineTib data set.

ldaModelData <- getLearnerModel(ldaModel)
ldaPreds <- predict(ldaModelData)$x
head(ldaPreds)
##         LD1       LD2
## 1 -4.700244 1.9791383
## 2 -4.301958 1.1704129
## 3 -3.420720 1.4291014
## 4 -4.205754 4.0028715
## 5 -1.509982 0.4512239
## 6 -4.518689 3.2131376

To visualise how the two learned DFs separate the bottles of wine from the 3 vineyards, we’ll plot them against each other. We start by piping the wineTib data set into a mutate() call where we create a new column for each of the DFs. We then pipe this mutated tibble into ggplot() & set LD1, LD2, & Class as the x, y, & colour aesthetics, respectively. Finally, we add a geom_point() layer to add dots, & a stat_ellipse() layer to draw 95% confidence ellipses around each class.

ggplotly(
  wineTib %>%
    mutate(LD1 = ldaPreds[, 1], LD2 = ldaPreds[, 2]) %>%
    ggplot(aes(LD1, LD2, colour = Class)) +
    geom_point() + stat_ellipse() +
    theme_bw()
)

We can see that LDA has reduced our 13 predictor variables into just two DFs that do an excellent job of separating the wines from each of the vineyards.
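If you’re curious exactly how the 13 predictors are combined into those two DFs, the underlying lda object stores this information. A short sketch, assuming the ldaModelData object extracted above (these are standard components of models fit by MASS::lda, which mlr wraps):

```r
# The scaling matrix holds the coefficients of the linear combinations
# that define LD1 & LD2 (one row per predictor, one column per DF)
ldaModelData$scaling

# The proportion of between-class variance captured by each DF can be
# computed from the singular values stored in the model
ldaModelData$svd^2 / sum(ldaModelData$svd^2)
```

Predictors with large (absolute) coefficients contribute most to separating the classes along that DF.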

Now, we’ll do the exact same procedure to build a QDA model.

qda <- makeLearner('classif.qda')
qdaModel <- train(qda, wineTask)

Note: Sadly, it isn’t easy to extract the DFs from the implementation of QDA that mlr uses, so we can’t plot them as we did for LDA.
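As a workaround (not part of the original analysis), we can still get a feel for how the QDA model partitions the data by plotting the LDA DFs we already computed, but colouring each case by QDA’s predicted class instead. This sketch assumes ldaPreds, qdaModel, & wineTib from above:

```r
# Get QDA's in-sample predicted classes from the mlr Prediction object
qdaResponse <- predict(qdaModel, newdata = wineTib)$data$response

# Plot cases on the LDA discriminant axes, coloured by QDA's prediction
wineTib %>%
  mutate(LD1 = ldaPreds[, 1],
         LD2 = ldaPreds[, 2],
         qdaPred = qdaResponse) %>%
  ggplot(aes(LD1, LD2, colour = qdaPred)) +
  geom_point() +
  theme_bw()
```

Cases whose colour disagrees with their cluster are the ones QDA misclassifies on the training data.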

Now, let’s cross-validate our LDA & QDA models together to estimate how they will perform on new data.

kFold <- makeResampleDesc(method = 'RepCV', folds = 10, reps = 50, 
                          stratify = TRUE)
ldaCV <- resample(learner = lda, task = wineTask, resampling = kFold, 
                  measures = list(mmce, acc))
qdaCV <- resample(learner = qda, task = wineTask, resampling = kFold,
                  measures = list(mmce, acc))
ldaCV$aggr
## mmce.test.mean  acc.test.mean 
##     0.01110819     0.98889181
qdaCV$aggr
## mmce.test.mean  acc.test.mean 
##    0.008273607    0.991726393

Our LDA model correctly classified 98.9% of wine bottles on average. There isn’t much room for improvement here, but our QDA model managed to correctly classify 99.2% of the cases. Let’s also look at the confusion matrices.

calculateConfusionMatrix(ldaCV$pred, relative = TRUE)
## Relative confusion matrix (normalized by row/column):
##         predicted
## true     1           2           3           -err.-     
##   1      1e+00/1e+00 3e-04/3e-04 0e+00/0e+00 3e-04      
##   2      8e-03/1e-02 1e+00/1e+00 1e-02/2e-02 2e-02      
##   3      0e+00/0e+00 7e-03/5e-03 1e+00/1e+00 7e-03      
##   -err.-       0.010       0.005       0.021 0.01       
## 
## 
## Absolute confusion matrix:
##         predicted
## true        1    2    3 -err.-
##   1      2949    1    0      1
##   2        30 3470   50     80
##   3         0   18 2382     18
##   -err.-   30   19   50     99
calculateConfusionMatrix(qdaCV$pred, relative = TRUE)
## Relative confusion matrix (normalized by row/column):
##         predicted
## true     1           2           3           -err.-     
##   1      0.996/0.983 0.004/0.004 0.000/0.000 0.004      
##   2      0.014/0.017 0.986/0.993 0.000/0.000 0.014      
##   3      0.000/0.000 0.005/0.003 0.995/1.000 0.005      
##   -err.-       0.017       0.007       0.000 0.008      
## 
## 
## Absolute confusion matrix:
##         predicted
## true        1    2    3 -err.-
##   1      2937   13    0     13
##   2        50 3500    0     50
##   3         0   11 2389     11
##   -err.-   50   24    0     74

Now, detective, the chemical analysis of the poisoned wine is in. Let’s use our QDA model to predict which vineyard it came from.

poisoned <- tibble(Alco = 13, Malic = 2, Ash = 2.2, Alk = 19, Mag = 100, Phe = 2.3, Flav = 2.5, Non_flav = 0.35, Proan = 1.7, Col = 4, Hue = 1.1, OD = 3, Prol = 750)

predict(qdaModel, newdata = poisoned)
## Prediction: 1 observations
## predict.type: response
## threshold: 
## time: 0.00
##   response
## 1        1
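The model points the finger at vineyard 1. If we also wanted to know how confident it is, a hedged sketch (assuming wineTask & the poisoned tibble from above) is to retrain the learner with predict.type = "prob" so predictions include posterior class probabilities:

```r
# A probability-outputting version of the QDA learner; not part of the
# original analysis
qdaProb <- makeLearner('classif.qda', predict.type = 'prob')
qdaProbModel <- train(qdaProb, wineTask)

# The prediction now contains one posterior probability per vineyard,
# alongside the predicted class
predict(qdaProbModel, newdata = poisoned)
```

A posterior near 1 for class 1 would mean the chemical profile of the poisoned wine is a much better match for vineyard 1 than for either rival.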